Add device arg & Ascend npu support #2536
Conversation
e739c3a to 728f8c2 (Compare)
Please help review and suggest any changes, thanks! @xingchensong @Mddct
Could you share the following first:
We don't have NPUs here, so it's hard for us to test.
wenet/bin/alignment.py
Outdated
parser.add_argument('--device',
                    type=str,
                    default="cpu",
                    help='accelerator to use')
Restrict the allowed values with the choices option:
choices=["cpu", "npu", "cuda"], etc.
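For concreteness, a minimal sketch of what the argument might look like with the suggested choices restriction (only the choices line is new relative to the snippet above; the final wording in the PR may differ):

```python
import argparse

parser = argparse.ArgumentParser(description='alignment example')
# Restrict --device to the supported accelerators, as suggested above.
parser.add_argument('--device',
                    type=str,
                    default='cpu',
                    choices=['cpu', 'npu', 'cuda'],
                    help='accelerator to use')

args = parser.parse_args(['--device', 'npu'])
print(args.device)  # -> npu
```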
done
Was the updated code not pushed? It still looks like the old version to me.
The latest code has been pushed.
OK, I'll tidy it up. Also, we are applying for NPU machines that can be used for community CI; we can contribute them to the community later to drive the development and maintenance of wenet + Ascend.
Screenshots of successful training and inference have been updated.
wenet/utils/common.py
Outdated
    import torch_npu  # noqa
    return True
except ImportError:
    print("Module \"torch_npu\" not found. \"pip install torch_npu\" \
OK.
About the model's device: there is no need to build it in __init__. The model is already moved to a device via .to(device) when it is initialized, and when a device is needed inside forward it can be obtained from the input.
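A minimal sketch of the pattern being suggested here; the module and tensor names below are made up for illustration and are not code from the PR:

```python
import torch


class TinyModel(torch.nn.Module):
    """No device stored in __init__; placement is decided by the caller via .to()."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inside forward, derive the device from the input tensor.
        device = x.device
        mask = torch.ones(x.size(0), dtype=torch.bool, device=device)
        return self.linear(x) * mask.unsqueeze(-1)


model = TinyModel()
# The caller moves the model once, e.g. model.to('cuda') or model.to('npu').
out = model(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8])
```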
Do you have any benchmark results, e.g. training speed and inference speed? We could write an article together to introduce this work.
@MengqingCao Hi, please add me on WeChat, communication here is a bit inefficient: currycode
I'm still debugging this part. Our machines are ARM-based, and building and installing openfst and srilm seems to have problems, so there are no benchmark results yet. Once I have results, I'd be happy to write an article introducing this work :)
@robin1001 @Mddct @xingchensong Latest benchmark: attention decoding accuracy meets the reference, while CTC decoding accuracy shows some deviation. The 4-GPU results come from the Conformer results in https://github.com/wenet-e2e/wenet/blob/main/examples/aishell/s0/README.md
I happen to have an experiment result on AISHELL-1 for reference. Train config: examples/aishell/s0/conf/train_conformer.yaml (7ce2126)
Thanks for sharing! Did you run the training on this branch?
Yes.
wenet/utils/executor.py
Outdated
@@ -106,7 +109,8 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                "lrs":
                [group['lr'] for group in optimizer.param_groups]
            })
-           save_model(model, info_dict)
+           if self.step % 100 == 0:
+               save_model(model, info_dict)
This self.step % 100 check duplicates the functionality at line 94.
if "cuda" in args.device: | ||
torch.cuda.set_device(local_rank) | ||
elif "npu" in args.device and TORCH_NPU_AVAILABLE: | ||
torch.npu.set_device(local_rank) |
Add an else branch to handle the unexpected case.
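One way the suggested else branch might look, wrapped as a standalone helper for illustration (the name set_accelerator and the error message are assumptions, not code from the PR):

```python
import torch


def set_accelerator(device: str, local_rank: int, torch_npu_available: bool) -> None:
    """Sketch: bind the current process to its accelerator, failing fast otherwise."""
    if "cuda" in device:
        torch.cuda.set_device(local_rank)
    elif "npu" in device and torch_npu_available:
        # torch.npu is registered by importing torch_npu, guarded by the flag above.
        torch.npu.set_device(local_rank)
    else:
        # Suggested else branch: surface misconfiguration instead of silently
        # continuing on an unintended device.
        raise ValueError(f"Unsupported device '{device}', or the required "
                         "runtime (e.g. torch_npu) is not available.")
```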
Could you share your NPU inference code or script? When I run inference on the NPU, it is even slower than on the CPU.
@285220927 I run inference directly with stage 5.
I use the same script for batched inference with batch_size set to 16, and it is unbearably slow.
related to #2513